import torch.nn as nn
import torch


class Parallel(nn.Module):
    def __init__(self, net1, net2):
        super().__init__()
        self.net1=net1 # 第一个子网络
        self.net2=net2 # 第二个子网络

    def forward(self,X):
        x1= self.net1(X) # 第一个子网络的输出
        x2= self.net2(X) # 第二个子网络的输出
        return torch.cat((x1,x2),dim=1) # 在通道维度上合并输出结果

X = torch.rand(2,10) # 输入数据
net = Parallel(nn.Sequential(nn.Linear(10,12),nn.ReLU()), nn.Sequential(nn.Linear(10,24),nn.ReLU())) # 实例化并行网络
output = net(X)
print(output.size()) # 输出结果的大小